import Optimizer
import random
import math

# this optimizer performs simulated annealing over fixed-width bit-string states
# each batch proposes unexplored neighbors of the current state within a
# temperature-dependent number of bit flips, then moves to the best one if it improves
class SimulatedAnnealing(Optimizer.Optimizer):
    """Simulated-annealing search over fixed-width bit-string states.

    A state is a non-negative integer read as a ``num_bits``-wide bit string.
    Every call to ``getNextStates`` proposes a batch of not-yet-explored
    neighbours of the current state. The neighbourhood radius (maximum number
    of flipped bits) comes from ``temperature_schedule``, evaluated on the
    fraction of batches still to come (1.0 for the first batch, 0.0 for the
    last). At the start of the following call the best neighbour of the
    previous batch replaces the current state if its value is strictly higher.

    Values of proposed states are read from ``self.explored_states`` (a
    state -> value mapping); presumably the ``Optimizer`` base class or its
    driver fills it in between calls -- TODO confirm.
    """

    def __init__(self, temperature_schedule, num_batches, states_per_batch, num_bits, number_top_states, characterizer, initial_state=None, verbose=False):
        """Configure the search.

        temperature_schedule: callable taking the fraction of batches remaining
            (1.0 -> 0.0) and returning a temperature; numberOfBitFlips maps
            temperatures 0..100 onto 1..num_bits flips (see linearCooling).
        num_batches: how many batches to iterate over.
        states_per_batch: how many states to propose per batch.
        num_bits: width of the bit-string state.
        number_top_states, characterizer, verbose: forwarded to Optimizer.
        initial_state: optional starting state. When omitted the search starts
            from state 0; when given, every temperature is halved.
        """
        super(SimulatedAnnealing, self).__init__(number_top_states, characterizer, verbose)
        self.verbose = verbose

        self.num_batches = num_batches
        self.states_per_batch = states_per_batch

        self.num_bits = num_bits
        self.maximum_state = 2**num_bits - 1

        self.batch_counter = 0

        self.temperatureSchedule = temperature_schedule

        # random_initialization is True when no initial state was supplied;
        # the search then starts from state 0
        self.random_initialization = initial_state is None
        self.current_state = 0 if initial_state is None else initial_state

        # starts at -inf so the first evaluated neighbour with a finite value
        # is accepted as an improvement
        self.current_value = float("-inf")
        self.next_states = []

        # Smallest radius whose unexplored neighbourhood (of state 0) is larger
        # than the total evaluation budget. Radii up to this threshold are
        # enumerated explicitly in getNextStates; larger ones are sampled.
        budget = self.num_batches * self.states_per_batch
        neighborHoodSize = 0
        finished = False
        while not finished:
            neighborHoodSize += 1
            num_neighbors = len(self.getUnexploredNeighborhood(0, neighborHoodSize))
            # A radius of num_bits already covers the whole state space, so stop
            # there; otherwise a budget >= 2**num_bits - 1 would loop forever.
            finished = num_neighbors > budget or neighborHoodSize >= self.num_bits
        self.neighborhood_threshold = neighborHoodSize


############### INTERFACE FUNCTIONS ###############

    def isFinished(self):
        """Return True once num_batches batches have been handed out."""
        if self.verbose:
            print("Checking if Finished")
        return self.batch_counter >= self.num_batches

    def getNextStates(self):
        """Return the next batch of states to evaluate.

        On every call but the first, first moves to the best state of the
        previous batch if it strictly improves on the current value. Then the
        scheduled temperature is turned into a neighbourhood radius and
        ``self.next_states`` is refilled with unexplored neighbours of the
        current state. Increments ``batch_counter``.

        Returns:
            The new batch (the same list object stored in ``self.next_states``).
            It is shorter than ``states_per_batch`` only when the explicitly
            enumerated neighbourhood has run out of unexplored states.
        """
        if self.batch_counter != 0:
            self._moveToBestNeighbor()

        current_temperature = self._currentTemperature()
        if self.verbose:
            print("Temperature = %s" % current_temperature)
        neighborHoodSize = self.numberOfBitFlips(current_temperature)

        if self.verbose:
            print("Getting Next States")
        self.next_states = []
        if neighborHoodSize <= self.neighborhood_threshold:
            self._selectFromExplicitNeighborhood(neighborHoodSize)
        else:
            self._selectRandomNeighbors(neighborHoodSize)

        self.batch_counter += 1
        return self.next_states


############### Child Class Helper Functions ###############

    def _moveToBestNeighbor(self):
        """Adopt the best-valued state of the previous batch if it beats the current value."""
        if self.verbose:
            print("Comparing States...")
        best_neighbor_state = None
        best_neighbor_value = float("-inf")
        for neighbor_state in self.next_states:
            neighbor_value = self.explored_states[neighbor_state]
            if neighbor_value > best_neighbor_value:
                best_neighbor_state = neighbor_state
                best_neighbor_value = neighbor_value
        # strict comparison: an empty batch (value -inf) never moves the state
        if self.current_value < best_neighbor_value:
            if self.verbose:
                print("Improvement Detected: Moving State")
            self.current_state = best_neighbor_state
            self.current_value = best_neighbor_value

    def _currentTemperature(self):
        """Temperature for the batch about to be produced (halved if an initial state was given)."""
        remaining_batches = self.num_batches - self.batch_counter
        if self.num_batches > 1:
            fraction_remaining = (remaining_batches - 1) / float(self.num_batches - 1)
        else:
            # a single batch is also the last one; avoids dividing by zero
            fraction_remaining = 0.0
        current_temperature = self.temperatureSchedule(fraction_remaining)
        # if an initial state is provided, halve the temperature so that at most half the bits can flip;
        # without this reduction the initial state would be meaningless as all bits could flip
        if not self.random_initialization:
            current_temperature = current_temperature / 2.0
        return current_temperature

    def _selectFromExplicitNeighborhood(self, neighborHoodSize):
        """Fill self.next_states by enumerating the unexplored neighbourhood.

        The neighbourhood may already be fully or nearly explored, so every
        candidate is listed. If all candidates fit in the remaining spots they
        are all taken and the radius grows by one. The remaining spots are then
        filled by random sampling from the first radius that no longer fits.
        """
        if self.verbose:
            print("Small Neighborhood, performing explicit selection...")
        possible_states = self.getUnexploredNeighborhood(self.current_state, neighborHoodSize)

        remaining_spots = self.states_per_batch - len(self.next_states)
        while remaining_spots > 0 and len(possible_states) <= remaining_spots:
            if self.verbose:
                print("Adding all possible neighbors of distance %s" % neighborHoodSize)
            # the whole neighbourhood fits, so explore all of it
            self.next_states += possible_states
            remaining_spots = self.states_per_batch - len(self.next_states)
            if neighborHoodSize >= self.num_bits:
                # radius num_bits already spans every state: nothing left to widen into
                possible_states = []
                break
            neighborHoodSize += 1
            # excludes states already placed in next_states
            possible_states = self.getUnexploredNeighborhood(self.current_state, neighborHoodSize)

        if remaining_spots > 0 and possible_states:
            if self.verbose:
                print("Randomly adding neighbors of distance %s" % neighborHoodSize)
            # loop exit guarantees len(possible_states) > remaining_spots here
            self.next_states += random.sample(possible_states, remaining_spots)

    def _selectRandomNeighbors(self, neighborHoodSize):
        """Fill self.next_states by rejection-sampling random bit flips.

        Used when the neighbourhood is too large to be fully explored. There is
        no exploration guarantee, so this might not terminate if too few
        unexplored states are reachable.
        """
        if self.verbose:
            print("Large Neighborhood, performing random selection...")
        pending = set(self.next_states)
        while len(self.next_states) < self.states_per_batch:
            possible_state = self.flipRandomBits(self.current_state, neighborHoodSize)
            # reject explored states and states already in this batch (no duplicates)
            if possible_state not in self.explored_states and possible_state not in pending:
                pending.add(possible_state)
                self.next_states.append(possible_state)

    def state2bin(self, state):
        """Return bin(state) ('0b'-prefixed) left-padded with zeros to num_bits digits.

        States wider than num_bits are not truncated.
        """
        bin_state = bin(state)
        state_bit_count = len(bin_state) - 2
        if state_bit_count < self.num_bits:
            missing_bits = self.num_bits - state_bit_count
            bin_state = bin_state[:2] + '0' * missing_bits + bin_state[2:]
        return bin_state

    def bin2state(self, bin_state):
        """Inverse of state2bin: parse a base-2 string into an integer state."""
        return int(bin_state, 2)

    def numberOfBitFlips(self, current_temperature):
        """Map a temperature to a flip count in [1, num_bits].

        Temperature is nominally 0..100: 0 means only 1 bit flips, 100 means all
        bits. Each 100/num_bits degrees adds one flip, rounded up.
        """
        degrees_per_bit = 100.0 / (self.num_bits)
        fractional_bits = current_temperature / degrees_per_bit
        num_flips = int(math.ceil(fractional_bits))
        # cannot flip more bits than there are
        num_flips = min(self.num_bits, num_flips)
        # make sure at least one bit flips at 0 degrees
        return max(num_flips, 1)

    def flipBit(self, bin_state, bit_idx):
        """Return bin_state with the digit at bit_idx inverted.

        bit_idx counts digits after the '0b' prefix, so index 0 is the most
        significant bit of the padded string.
        """
        bit_idx = bit_idx + 2
        flipped = '1' if bin_state[bit_idx] == '0' else '0'
        return bin_state[:bit_idx] + flipped + bin_state[(bit_idx + 1):]

    def flipRandomBits(self, state, number_of_bits):
        """Pick number_of_bits distinct positions and flip each with probability 1/2.

        The result therefore differs from state in at most number_of_bits bits
        and may equal state itself.
        """
        bin_state = self.state2bin(state)
        # randomly select which bits can flip
        bits_to_flip = random.sample(range(self.num_bits), number_of_bits)
        for bit_idx in bits_to_flip:
            # randomly flip bit
            if random.random() >= 0.5:
                bin_state = self.flipBit(bin_state, bit_idx)
        return self.bin2state(bin_state)

    def getUnexploredNeighborhood(self, state, neighborHoodSize):
        """Return states 1..neighborHoodSize bit flips from state.

        Excludes states that are already explored or already in next_states.
        Results are ordered by flip count, then by flipped positions.
        """
        bin_state = self.state2bin(state)
        bin_neighborhood = []
        for num_flips in range(1, neighborHoodSize + 1):
            bin_neighborhood += self.getBinNeighborhood(bin_state, 0, num_flips)
        neighborhood = self.binNeighborhood2Neighborhood(bin_neighborhood)
        # prune all neighbors that have been, or have already been selected to be, explored
        pending = set(self.next_states)
        return [neighbor for neighbor in neighborhood
                if neighbor not in self.explored_states and neighbor not in pending]

    def getBinNeighborhood(self, bin_state, start_idx, num_to_flip):
        """Return every string made by flipping exactly num_to_flip distinct positions >= start_idx.

        Positions are chosen in increasing order, so each combination appears once.
        """
        if num_to_flip == 0:
            return [bin_state]
        neighborhood = []
        for bit_idx in range(start_idx, self.num_bits):
            new_state = self.flipBit(bin_state, bit_idx)
            neighborhood += self.getBinNeighborhood(new_state, bit_idx + 1, num_to_flip - 1)
        return neighborhood

    def binNeighborhood2Neighborhood(self, binNeighborhood):
        """Convert a sequence of '0b' strings into integer states, preserving order."""
        return [self.bin2state(bin_state) for bin_state in binNeighborhood]

    @staticmethod
    def linearCooling(fraction_remaining):
        """Temperature falls linearly from 100 (fraction 1.0) to 0 (fraction 0.0)."""
        return 100 * fraction_remaining

    @staticmethod
    def linearCoolingHoldZero(fraction_remaining):
        """Linear cooling from 100 at fraction 1.0 that reaches 0 at fraction 1/6.

        Holds 0 for the final sixth of the run; output is clamped to [0, 100].
        """
        fraction_at_zero = 1 / 6.0
        start_temp = 100.0 / (1 - fraction_at_zero)
        temp = start_temp * (fraction_remaining - fraction_at_zero)
        temp = max(temp, 0)
        temp = min(temp, 100)
        return temp
